%Matlab code that implements the algorithm in
%"Zexi Huang, Transfer Learning for Community Detection. Bachelor thesis, University of Electronic Science and Technology of China, Chengdu, China, 2018."
%Zexi Huang
%9 June 2018

function transfer_learning_for_community_detection(outputFile,inputFile,targetLayerID,kCom,c,ck,T1,T2)
%Detect communities in one layer of a multiplex network, using a latent
%representation learned jointly from all layers.
%
%outputFile: the file that receives the community memberships found in the target layer. Each line has the format: nodeID communityID.
%inputFile: the file that contains the connection information of the multiplex network. Each line should have the format: layerID nodeID nodeID.
%           Layer and node IDs must be positive integers (they are used directly as indices). Rows need not be sorted by layerID.
%targetLayerID: the layerID of the target layer that receives the transferred knowledge.
%kCom: number of communities to be detected in the target layer.
%c: number of dimensions of the common feature matrix.
%ck: number of dimensions of the layer-specific feature matrices: either one entry per layer, or a scalar used for every layer.
%T1: number of iterations of the learning algorithm (all T1 iterations are run; there is no convergence test).
%T2: maximum number of iterations of the clustering algorithm (passed to kmeans as 'MaxIter').
%
%The latent matrices are initialized with rand and clustered with kmeans
%(Statistics and Machine Learning Toolbox), so results vary between runs
%unless the random number generator is seeded (e.g. rng(seed)) beforehand.
%
%NOTE: the helpers below are nested functions. A variable name used both in
%this body and inside a helper (other than as that helper's argument) is
%shared between them, so keep helper-local names distinct from the names
%used here.

%Read the adjacency matrices.
A=read_graph(inputFile);

%A scalar ck means "the same dimension for every layer". Otherwise ck must
%list one dimension per layer; a mismatch would silently drop layers or fail
%deep inside the learning loop.
if isscalar(ck)
    ck=repmat(ck,1,numel(A));
elseif numel(ck)~=numel(A)
    error('transfer_learning_for_community_detection:ckSize', ...
        'ck must be a scalar or have one entry per layer (%d layers found in %s).',numel(A),inputFile);
end
if ~isscalar(targetLayerID) || targetLayerID~=fix(targetLayerID) || targetLayerID<1 || targetLayerID>numel(A)
    error('transfer_learning_for_community_detection:targetLayerID', ...
        'targetLayerID must be an integer between 1 and %d.',numel(A));
end

%Compute the utility matrices (here simply the adjacency matrices).
U=A;

%Learn the latent representations from utility matrices.
[B,X]=learning_algorithm(U,c,ck,T1);

%Cluster the combined latent representations to obtain the community partition.
C=clustering_algorithm(X,B{targetLayerID},kCom,T2);

%Store the found community memberships into file.
store_communities(C,outputFile);

    function A=read_graph(filename)
        %Read per-layer adjacency matrices from an edge-list file.
        %Each row of the file is "layerID nodeID nodeID". Returns a 1-by-L
        %cell array (L = largest layerID). Every layer is a dense n-by-n
        %0/1 matrix (n = largest nodeID over all layers) that is symmetric
        %and has a zero diagonal. A layerID with no rows gives an all-zero
        %matrix.
        
        edges=load(filename);
        if isempty(edges) || size(edges,2)<3
            error('transfer_learning_for_community_detection:badInput', ...
                '%s must contain rows of the form "layerID nodeID nodeID".',filename);
        end
        %Number of layers and number of nodes.
        layerCount=max(edges(:,1));
        nodeCount=max(max(edges(:,2:3)));
        A=cell(1,layerCount);
        
        for layer=1:layerCount
            %Select this layer's rows by mask, so the file need not be sorted by layerID.
            inLayer=edges(:,1)==layer;
            Ai=sparse(edges(inLayer,2),edges(inLayer,3),ones(nnz(inLayer),1),nodeCount,nodeCount);
            %sparse() sums repeated (i,j) pairs, so binarize before symmetrizing.
            Ai=full(double(Ai~=0));
            %Force symmetry.
            Ai=max(Ai,Ai');
            %Remove self-loops.
            Ai(1:nodeCount+1:end)=0;
            A{layer}=Ai;
        end
    end

    function [B,X]=learning_algorithm(U,c,ck,T1)
        %Learn the common feature matrix X (n-by-c) and the layer-specific
        %matrices B{k} (n-by-ck(k)) from the utility matrices U{k}.
        %The update rules below presumably minimize
        %sum_k ||U{k} - X*X' - B{k}*B{k}'||_F^2 (see the thesis cited at the
        %top of the file). Both factors start from rand values and receive
        %exactly T1 rounds of multiplicative updates.
        %Expects numel(ck)==numel(U); the main function enforces this.
        
        N=numel(U);%number of layers.
        n=size(U{1},1);%number of nodes.
        
        %Initialization
        X=rand(n,c);
        B=cell(1,N);
        for k=1:N
            B{k}=rand(n,ck(k));
        end
        
        %sum_k U{k} does not change between iterations, so compute it once.
        Usum=zeros(n,n);
        for k=1:N
            Usum=Usum+U{k};
        end
        
        %Multiplicative update.
        for i=1:T1
            %Update X with P = sum_k (U{k} - B{k}*B{k}') and Q = N*(X*X')*X.
            P=Usum;
            for k=1:N
                P=P-B{k}*B{k}';
            end
            X=multiplicative_update(X,P,N*(X*X')*X);
            
            %Update each B{k} against the freshly updated X.
            XXt=X*X';
            for k=1:N
                R=U{k}-XXt;
                S=(B{k}*B{k}')*B{k};
                B{k}=multiplicative_update(B{k},R,S);
            end
        end
    end

    function Y=multiplicative_update(X,P,Q)
        %One multiplicative step: Y = X .* sqrt((P+ * X) ./ (P- * X + Q)),
        %where P = P+ - P- splits P into its nonnegative parts. Any entry
        %whose numerator or denominator is not strictly positive is set to 0.
        
        Pposi=max(P,0);
        Pnega=max(-P,0);
        M=Pposi*X;
        D=Pnega*X+Q;
        
        Y=zeros(size(X));
        ok=M>0 & D>0;
        Y(ok)=X(ok).*sqrt(M(ok)./D(ok));
    end

    function C=clustering_algorithm(X,Bk,kCom,T2)
        %Cluster the rows of the combined features [X Bk] into kCom clusters
        %with k-means (at most T2 iterations). Returns one cluster index per node.
        
        S=cat(2,X,Bk);
        C=kmeans(S,kCom,'MaxIter',T2);
    end

    function store_communities(C,outputFile)
        %Write one "nodeID communityID" line per node to outputFile, space-delimited.
        %Node IDs are the row positions 1..numel(C).
        
        augC=cat(2,(1:numel(C))',C(:));
        dlmwrite(outputFile,augC,'delimiter',' ');
    end

end

